//===--- FloatingPointDifferentiation.swift.gyb ---------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import Swift
import SwiftShims

% from SwiftFloatingPointTypes import all_floating_point_types
% for self_type in all_floating_point_types():
%{
# Name of the concrete floating-point type for this loop iteration; it is
# substituted wherever `${Self}` appears in the template below.
Self = self_type.stdlib_name
# Bit width of the type; it drives `Availability(bits)` and the `bits == 80`
# conditional-compilation guards further down.
bits = self_type.bits

def Availability(bits):
    # Only the 16-bit type is annotated with an explicit `@available`
    # attribute; every other bit width expands to an empty string.
    half_precision_attr = '@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *)'
    return half_precision_attr if bits == 16 else ''
}%

% if bits == 80:
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
% end

//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//

${Availability(bits)}
extension ${Self}: Differentiable {
  /// `${Self}` serves as its own tangent vector type.
  ${Availability(bits)}
  public typealias TangentVector = ${Self}

  /// Moves `self` along the given direction by adding `direction` to it.
  ///
  /// - Parameter direction: The tangent vector to add to `self`.
  ${Availability(bits)}
  public mutating func move(along direction: TangentVector) {
    self = self + direction
  }
}

//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//

/// Derivatives of standard unary operators.
${Availability(bits)}
extension ${Self} {
  /// Reverse-mode derivative of prefix `-`.
  ///
  /// Returns `-x` together with a pullback that negates the incoming
  /// cotangent.
  @usableFromInline
  @_transparent
  @derivative(of: -)
  static func _vjpNegate(
    x: ${Self}
  ) -> (value: ${Self}, pullback: (${Self}) -> ${Self}) {
    let negated = -x
    return (negated, { -$0 })
  }

  /// Forward-mode derivative of prefix `-`.
  ///
  /// Returns `-x` together with a differential that negates the incoming
  /// tangent.
  @usableFromInline
  @_transparent
  @derivative(of: -)
  static func _jvpNegate(
    x: ${Self}
  ) -> (value: ${Self}, differential: (${Self}) -> ${Self}) {
    let negated = -x
    return (negated, { -$0 })
  }
}

/// Derivatives of standard binary operators.
///
/// Each `_vjp*` function returns the operator's value together with a pullback
/// (reverse mode); each `_jvp*` function returns it together with a
/// differential (forward mode). For the compound-assignment operators the
/// `inout` closure parameter carries the derivative of `lhs` and is updated in
/// place.
${Availability(bits)}
extension ${Self} {
  /// Reverse-mode derivative of `+`: the cotangent `v` is passed unchanged to
  /// both operands.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: +)
  static func _vjpAdd(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
    return (lhs + rhs, { v in (v, v) })
  }

  /// Forward-mode derivative of `+`: the result tangent is `dlhs + drhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: +)
  static func _jvpAdd(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
    return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs })
  }

  /// Reverse-mode derivative of `+=`: the pullback leaves the cotangent of
  /// `lhs` untouched and returns it as the cotangent of `rhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: +=)
  static func _vjpAddAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, pullback: (inout ${Self}) -> ${Self}
  ) {
    lhs += rhs
    return ((), { v in v })
  }

  /// Forward-mode derivative of `+=`: the differential adds the tangent of
  /// `rhs` to the tangent of `lhs` in place.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: +=)
  static func _jvpAddAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, differential: (inout ${Self}, ${Self}) -> Void
  ) {
    lhs += rhs
    return ((), { $0 += $1 })
  }

  /// Reverse-mode derivative of binary `-`: the pullback maps `v` to
  /// `(v, -v)`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: -)
  static func _vjpSubtract(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
    return (lhs - rhs, { v in (v, -v) })
  }

  /// Forward-mode derivative of binary `-`: the result tangent is
  /// `dlhs - drhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: -)
  static func _jvpSubtract(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
    return (lhs - rhs, { (dlhs, drhs) in dlhs - drhs })
  }

  /// Reverse-mode derivative of `-=`: the pullback leaves the cotangent of
  /// `lhs` untouched and returns its negation as the cotangent of `rhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: -=)
  static func _vjpSubtractAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, pullback: (inout ${Self}) -> ${Self}
  ) {
    lhs -= rhs
    return ((), { v in -v })
  }

  /// Forward-mode derivative of `-=`: the differential subtracts the tangent
  /// of `rhs` from the tangent of `lhs` in place.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: -=)
  static func _jvpSubtractAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, differential: (inout ${Self}, ${Self}) -> Void
  ) {
    lhs -= rhs
    return ((), { $0 -= $1 })
  }

  /// Reverse-mode derivative of `*`: the pullback maps `v` to
  /// `(rhs * v, lhs * v)`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: *)
  static func _vjpMultiply(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
    return (lhs * rhs, { v in (rhs * v, lhs * v) })
  }

  /// Forward-mode derivative of `*` (product rule): the result tangent is
  /// `lhs * drhs + rhs * dlhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: *)
  static func _jvpMultiply(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
    return (lhs * rhs, { (dlhs, drhs) in lhs * drhs + rhs * dlhs })
  }

  /// Reverse-mode derivative of `*=`: the pullback returns
  /// `originalLhs * v` as the cotangent of `rhs` and scales the cotangent of
  /// `lhs` by `rhs` in place.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: *=)
  static func _vjpMultiplyAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, pullback: (inout ${Self}) -> ${Self}
  ) {
    // The in-place multiplication is deferred until after the return
    // expression is evaluated, so the capture list below copies the value of
    // `lhs` from before the update.
    defer { lhs *= rhs }
    return ((), { [lhs = lhs] v in
      // Compute the `rhs` cotangent before `v` is overwritten.
      let drhs = lhs * v
      v *= rhs
      return drhs
    })
  }

  /// Forward-mode derivative of `*=` (product rule): the differential replaces
  /// the tangent of `lhs` with `dlhs * rhs + originalLhs * drhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: *=)
  static func _jvpMultiplyAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, differential: (inout ${Self}, ${Self}) -> Void
  ) {
    // The product rule needs the value of `lhs` from before the in-place
    // update, so save it first.
    let oldLhs = lhs
    lhs *= rhs
    return ((), { $0 = $0 * rhs + oldLhs * $1 })
  }

  /// Reverse-mode derivative of `/`: the pullback maps `v` to
  /// `(v / rhs, -lhs / (rhs * rhs) * v)`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: /)
  static func _vjpDivide(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
    return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) })
  }

  /// Forward-mode derivative of `/` (quotient rule): the result tangent is
  /// `dlhs / rhs - lhs / (rhs * rhs) * drhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: /)
  static func _jvpDivide(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
    return (
      lhs / rhs,
      { (dlhs, drhs) in dlhs / rhs - lhs / (rhs * rhs) * drhs }
    )
  }

  /// Reverse-mode derivative of `/=`: the pullback returns
  /// `-originalLhs / (rhs * rhs) * v` as the cotangent of `rhs` and divides
  /// the cotangent of `lhs` by `rhs` in place.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: /=)
  static func _vjpDivideAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, pullback: (inout ${Self}) -> ${Self}
  ) {
    // The in-place division is deferred until after the return expression is
    // evaluated, so the capture list below copies the value of `lhs` from
    // before the update.
    defer { lhs /= rhs }
    return ((), { [lhs = lhs] v in
      // Compute the `rhs` cotangent before `v` is overwritten.
      let drhs = -lhs / (rhs * rhs) * v
      v /= rhs
      return drhs
    })
  }

  /// Forward-mode derivative of `/=` (quotient rule): the differential
  /// replaces the tangent of `lhs` with
  /// `dlhs / rhs - originalLhs / (rhs * rhs) * drhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: /=)
  static func _jvpDivideAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, differential: (inout ${Self}, ${Self}) -> Void
  ) {
    // The quotient rule needs the value of `lhs` from before the in-place
    // update, so save it first.
    let oldLhs = lhs
    lhs /= rhs
    return ((), { $0 = $0 / rhs - oldLhs / (rhs * rhs) * $1 })
  }
}

% if bits == 80:
#endif
% end
% end

/// Derivatives of `FloatingPoint` methods, available when the conforming type
/// is `Differentiable` and is its own tangent vector.
extension FloatingPoint
where
  Self: Differentiable,
  Self == Self.TangentVector
{
  /// Reverse-mode derivative of `addingProduct(_:_:)`.
  ///
  /// `addingProduct(lhs, rhs)` is documented as `self + lhs * rhs`, so its
  /// partial derivatives with respect to `self`, `lhs` and `rhs` are `1`,
  /// `rhs` and `lhs`. The pullback scales each of them by the incoming
  /// cotangent `v`.
  ///
  /// - Returns: The value of `addingProduct(lhs, rhs)` and a pullback mapping
  ///   `v` to the cotangents of `(self, lhs, rhs)`.
  @inlinable // FIXME(sil-serialize-all)
  @derivative(of: addingProduct)
  func _vjpAddingProduct(
    _ lhs: Self, _ rhs: Self
  ) -> (value: Self, pullback: (Self) -> (Self, Self, Self)) {
    return (addingProduct(lhs, rhs), { v in (v, v * rhs, v * lhs) })
  }

  /// Reverse-mode derivative of `squareRoot()`.
  ///
  /// - Returns: The square root `y` and a pullback mapping `v` to
  ///   `v / (2 * y)`.
  @inlinable // FIXME(sil-serialize-all)
  @derivative(of: squareRoot)
  func _vjpSquareRoot() -> (value: Self, pullback: (Self) -> Self) {
    // Reuse the computed root in the pullback instead of recomputing it.
    let y = squareRoot()
    return (y, { v in v / (2 * y) })
  }
}
